import torch
from torch import nn

# Toy configuration: a stacked Elman RNN run over one short random sequence,
# printing the shapes of the per-step outputs and of the final hidden state.
seq_len, batch_size = 4, 1
vocab_size, hidden_size = 10, 20
num_layers = 2  # number of stacked recurrent (hidden-state) layers

# Input uses nn.RNN's default time-major layout (seq_len, batch, input_size),
# since batch_first is left at its default.
inputs = torch.randn(seq_len, batch_size, vocab_size)
# Initial hidden state: one (batch, hidden_size) slice per stacked layer.
h0 = torch.randn(num_layers, batch_size, hidden_size)
rnn = nn.RNN(
    input_size=vocab_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
)
# results: top-layer output at every time step; h: final hidden state per layer.
results, h = rnn(inputs, h0)

for tensor in (results, h):
    print(tensor.shape)
